{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 2SLS Benchmarks with NLSYM + Synthetic Datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "We demonstrate how to use the 2SLS estimator (`IVRegressor`) from the `causalml` package to estimate the average treatment effect on semi-synthetic and fully synthetic data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T18:34:07.556482Z",
     "start_time": "2020-06-22T18:34:07.075342Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T18:34:08.026503Z",
     "start_time": "2020-06-22T18:34:07.998418Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "base_path = os.path.abspath(\"../\")\n",
    "os.chdir(base_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:20:57.201672Z",
     "start_time": "2020-06-22T19:20:57.102105Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sys\n",
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:04:33.015381Z",
     "start_time": "2020-06-22T19:04:32.964582Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import causalml\n",
    "from causalml.inference.iv import IVRegressor\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import statsmodels.api as sm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Semi-Synthetic Data from NLSYM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:32:40.653280Z",
     "start_time": "2020-06-22T19:32:40.595806Z"
    }
   },
   "source": [
    "The data generation mechanism is described in Syrgkanis et al., \"*Machine Learning Estimation of Heterogeneous Treatment Effects with Instruments*\" (2019)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Data Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T18:34:16.787310Z",
     "start_time": "2020-06-22T18:34:16.720144Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "df = pd.read_csv(\"docs/examples/data/card.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T18:34:17.310674Z",
     "start_time": "2020-06-22T18:34:17.231429Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>nearc2</th>\n",
       "      <th>nearc4</th>\n",
       "      <th>educ</th>\n",
       "      <th>age</th>\n",
       "      <th>fatheduc</th>\n",
       "      <th>motheduc</th>\n",
       "      <th>weight</th>\n",
       "      <th>momdad14</th>\n",
       "      <th>sinmom14</th>\n",
       "      <th>...</th>\n",
       "      <th>smsa66</th>\n",
       "      <th>wage</th>\n",
       "      <th>enroll</th>\n",
       "      <th>kww</th>\n",
       "      <th>iq</th>\n",
       "      <th>married</th>\n",
       "      <th>libcrd14</th>\n",
       "      <th>exper</th>\n",
       "      <th>lwage</th>\n",
       "      <th>expersq</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7</td>\n",
       "      <td>29</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>158413</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>548</td>\n",
       "      <td>0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>16</td>\n",
       "      <td>6.306275</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>12</td>\n",
       "      <td>27</td>\n",
       "      <td>8.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>380166</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>481</td>\n",
       "      <td>0</td>\n",
       "      <td>35.0</td>\n",
       "      <td>93.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>9</td>\n",
       "      <td>6.175867</td>\n",
       "      <td>81</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>12</td>\n",
       "      <td>34</td>\n",
       "      <td>14.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>367470</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>721</td>\n",
       "      <td>0</td>\n",
       "      <td>42.0</td>\n",
       "      <td>103.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>16</td>\n",
       "      <td>6.580639</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>11</td>\n",
       "      <td>27</td>\n",
       "      <td>11.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>380166</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>250</td>\n",
       "      <td>0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>88.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>10</td>\n",
       "      <td>5.521461</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>6</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>12</td>\n",
       "      <td>34</td>\n",
       "      <td>8.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>367470</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>729</td>\n",
       "      <td>0</td>\n",
       "      <td>34.0</td>\n",
       "      <td>108.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>16</td>\n",
       "      <td>6.591674</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 34 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   id  nearc2  nearc4  educ  age  fatheduc  motheduc  weight  momdad14  \\\n",
       "0   2       0       0     7   29       NaN       NaN  158413         1   \n",
       "1   3       0       0    12   27       8.0       8.0  380166         1   \n",
       "2   4       0       0    12   34      14.0      12.0  367470         1   \n",
       "3   5       1       1    11   27      11.0      12.0  380166         1   \n",
       "4   6       1       1    12   34       8.0       7.0  367470         1   \n",
       "\n",
       "   sinmom14  ...  smsa66  wage  enroll   kww     iq  married  libcrd14  exper  \\\n",
       "0         0  ...       1   548       0  15.0    NaN      1.0       0.0     16   \n",
       "1         0  ...       1   481       0  35.0   93.0      1.0       1.0      9   \n",
       "2         0  ...       1   721       0  42.0  103.0      1.0       1.0     16   \n",
       "3         0  ...       1   250       0  25.0   88.0      1.0       1.0     10   \n",
       "4         0  ...       1   729       0  34.0  108.0      1.0       0.0     16   \n",
       "\n",
       "      lwage  expersq  \n",
       "0  6.306275      256  \n",
       "1  6.175867       81  \n",
       "2  6.580639      256  \n",
       "3  5.521461      100  \n",
       "4  6.591674      256  \n",
       "\n",
       "[5 rows x 34 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T18:34:17.900311Z",
     "start_time": "2020-06-22T18:34:17.855260Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['id', 'nearc2', 'nearc4', 'educ', 'age', 'fatheduc', 'motheduc',\n",
       "       'weight', 'momdad14', 'sinmom14', 'step14', 'reg661', 'reg662',\n",
       "       'reg663', 'reg664', 'reg665', 'reg666', 'reg667', 'reg668',\n",
       "       'reg669', 'south66', 'black', 'smsa', 'south', 'smsa66', 'wage',\n",
       "       'enroll', 'kww', 'iq', 'married', 'libcrd14', 'exper', 'lwage',\n",
       "       'expersq'], dtype=object)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.columns.values\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:02:35.660584Z",
     "start_time": "2020-06-22T19:02:35.515002Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Keep only the rows with educ >= 6\n",
    "data_filter = df['educ'] >= 6\n",
    "df_filtered = df.loc[data_filter]\n",
    "# outcome variable\n",
    "y = df_filtered['lwage'].values\n",
    "# treatment variable\n",
    "treatment = df_filtered['educ'].values\n",
    "# instrumental variable\n",
    "w = df_filtered['nearc4'].values\n",
    "\n",
    "covariate_cols = ['fatheduc', 'motheduc', 'momdad14', 'sinmom14', 'reg661', 'reg662',\n",
    "                  'reg663', 'reg664', 'reg665', 'reg666', 'reg667', 'reg668',\n",
    "                  'reg669', 'south66', 'black', 'smsa', 'south', 'smsa66',\n",
    "                  'exper', 'expersq']\n",
    "continuous_cols = ['fatheduc', 'motheduc', 'exper', 'expersq']\n",
    "\n",
    "# Take an explicit copy so the imputation below writes to an independent frame\n",
    "# rather than to a slice of df (the pattern behind SettingWithCopyWarning).\n",
    "Xdf = df_filtered[covariate_cols].copy()\n",
    "# Fill missing parental education with the column mean of the filtered sample\n",
    "Xdf['fatheduc'] = Xdf['fatheduc'].fillna(value=Xdf['fatheduc'].mean())\n",
    "Xdf['motheduc'] = Xdf['motheduc'].fillna(value=Xdf['motheduc'].mean())\n",
    "# Standardize the continuous covariates only; the other columns are left unchanged\n",
    "Xscale = Xdf.copy()\n",
    "Xscale[continuous_cols] = StandardScaler().fit_transform(Xscale[continuous_cols])\n",
    "X = Xscale.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:02:47.412463Z",
     "start_time": "2020-06-22T19:02:47.251852Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fatheduc</th>\n",
       "      <th>motheduc</th>\n",
       "      <th>momdad14</th>\n",
       "      <th>sinmom14</th>\n",
       "      <th>reg661</th>\n",
       "      <th>reg662</th>\n",
       "      <th>reg663</th>\n",
       "      <th>reg664</th>\n",
       "      <th>reg665</th>\n",
       "      <th>reg666</th>\n",
       "      <th>reg667</th>\n",
       "      <th>reg668</th>\n",
       "      <th>reg669</th>\n",
       "      <th>south66</th>\n",
       "      <th>black</th>\n",
       "      <th>smsa</th>\n",
       "      <th>south</th>\n",
       "      <th>smsa66</th>\n",
       "      <th>exper</th>\n",
       "      <th>expersq</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>2.991000e+03</td>\n",
       "      <td>2.991000e+03</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2991.000000</td>\n",
       "      <td>2.991000e+03</td>\n",
       "      <td>2.991000e+03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>-3.529069e-16</td>\n",
       "      <td>-1.704346e-15</td>\n",
       "      <td>0.790371</td>\n",
       "      <td>0.100301</td>\n",
       "      <td>0.046807</td>\n",
       "      <td>0.161484</td>\n",
       "      <td>0.196924</td>\n",
       "      <td>0.064527</td>\n",
       "      <td>0.205951</td>\n",
       "      <td>0.094952</td>\n",
       "      <td>0.109997</td>\n",
       "      <td>0.028419</td>\n",
       "      <td>0.090939</td>\n",
       "      <td>0.410899</td>\n",
       "      <td>0.231361</td>\n",
       "      <td>0.715145</td>\n",
       "      <td>0.400201</td>\n",
       "      <td>0.651622</td>\n",
       "      <td>4.285921e-16</td>\n",
       "      <td>3.040029e-17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>1.000167e+00</td>\n",
       "      <td>1.000167e+00</td>\n",
       "      <td>0.407112</td>\n",
       "      <td>0.300451</td>\n",
       "      <td>0.211261</td>\n",
       "      <td>0.368039</td>\n",
       "      <td>0.397741</td>\n",
       "      <td>0.245730</td>\n",
       "      <td>0.404463</td>\n",
       "      <td>0.293197</td>\n",
       "      <td>0.312938</td>\n",
       "      <td>0.166193</td>\n",
       "      <td>0.287571</td>\n",
       "      <td>0.492079</td>\n",
       "      <td>0.421773</td>\n",
       "      <td>0.451421</td>\n",
       "      <td>0.490021</td>\n",
       "      <td>0.476536</td>\n",
       "      <td>1.000167e+00</td>\n",
       "      <td>1.000167e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>-3.101056e+00</td>\n",
       "      <td>-3.502453e+00</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-2.159127e+00</td>\n",
       "      <td>-1.147691e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>-6.303764e-01</td>\n",
       "      <td>-4.656485e-01</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-6.858865e-01</td>\n",
       "      <td>-7.077287e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>2.091970e-01</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-1.948066e-01</td>\n",
       "      <td>-3.655360e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>6.049634e-01</td>\n",
       "      <td>5.466197e-01</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>5.418134e-01</td>\n",
       "      <td>3.310707e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>2.457973e+00</td>\n",
       "      <td>2.571156e+00</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>3.242753e+00</td>\n",
       "      <td>4.767355e+00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           fatheduc      motheduc     momdad14     sinmom14       reg661  \\\n",
       "count  2.991000e+03  2.991000e+03  2991.000000  2991.000000  2991.000000   \n",
       "mean  -3.529069e-16 -1.704346e-15     0.790371     0.100301     0.046807   \n",
       "std    1.000167e+00  1.000167e+00     0.407112     0.300451     0.211261   \n",
       "min   -3.101056e+00 -3.502453e+00     0.000000     0.000000     0.000000   \n",
       "25%   -6.303764e-01 -4.656485e-01     1.000000     0.000000     0.000000   \n",
       "50%    0.000000e+00  2.091970e-01     1.000000     0.000000     0.000000   \n",
       "75%    6.049634e-01  5.466197e-01     1.000000     0.000000     0.000000   \n",
       "max    2.457973e+00  2.571156e+00     1.000000     1.000000     1.000000   \n",
       "\n",
       "            reg662       reg663       reg664       reg665       reg666  \\\n",
       "count  2991.000000  2991.000000  2991.000000  2991.000000  2991.000000   \n",
       "mean      0.161484     0.196924     0.064527     0.205951     0.094952   \n",
       "std       0.368039     0.397741     0.245730     0.404463     0.293197   \n",
       "min       0.000000     0.000000     0.000000     0.000000     0.000000   \n",
       "25%       0.000000     0.000000     0.000000     0.000000     0.000000   \n",
       "50%       0.000000     0.000000     0.000000     0.000000     0.000000   \n",
       "75%       0.000000     0.000000     0.000000     0.000000     0.000000   \n",
       "max       1.000000     1.000000     1.000000     1.000000     1.000000   \n",
       "\n",
       "            reg667       reg668       reg669      south66        black  \\\n",
       "count  2991.000000  2991.000000  2991.000000  2991.000000  2991.000000   \n",
       "mean      0.109997     0.028419     0.090939     0.410899     0.231361   \n",
       "std       0.312938     0.166193     0.287571     0.492079     0.421773   \n",
       "min       0.000000     0.000000     0.000000     0.000000     0.000000   \n",
       "25%       0.000000     0.000000     0.000000     0.000000     0.000000   \n",
       "50%       0.000000     0.000000     0.000000     0.000000     0.000000   \n",
       "75%       0.000000     0.000000     0.000000     1.000000     0.000000   \n",
       "max       1.000000     1.000000     1.000000     1.000000     1.000000   \n",
       "\n",
       "              smsa        south       smsa66         exper       expersq  \n",
       "count  2991.000000  2991.000000  2991.000000  2.991000e+03  2.991000e+03  \n",
       "mean      0.715145     0.400201     0.651622  4.285921e-16  3.040029e-17  \n",
       "std       0.451421     0.490021     0.476536  1.000167e+00  1.000167e+00  \n",
       "min       0.000000     0.000000     0.000000 -2.159127e+00 -1.147691e+00  \n",
       "25%       0.000000     0.000000     0.000000 -6.858865e-01 -7.077287e-01  \n",
       "50%       1.000000     0.000000     1.000000 -1.948066e-01 -3.655360e-01  \n",
       "75%       1.000000     1.000000     1.000000  5.418134e-01  3.310707e-01  \n",
       "max       1.000000     1.000000     1.000000  3.242753e+00  4.767355e+00  "
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Xscale.describe()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Semi-Synthetic Data Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:00:36.976866Z",
     "start_time": "2020-06-22T19:00:36.929353Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def semi_synth_nlsym(X, w, random_seed=None):\n",
    "    \"\"\"Simulate a treatment and an outcome from covariates and an instrument.\n",
    "\n",
    "    Args:\n",
    "        X (np.ndarray): 2-D covariate matrix. Only column 1 and column 3 are read;\n",
    "            in this notebook they hold mother's education and the single-mom-at-14\n",
    "            indicator.\n",
    "        w (np.ndarray): instrument values, one per row of X.\n",
    "        random_seed (int, optional): seed for the random draws. None gives a\n",
    "            different sample on every call.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (y, treatment, theta), the simulated outcome, the simulated\n",
    "            treatment and the per-row treatment effect.\n",
    "    \"\"\"\n",
    "    # A private generator produces the same draws as np.random.seed(random_seed)\n",
    "    # followed by np.random.* calls, without resetting the global NumPy RNG.\n",
    "    rng = np.random.RandomState(random_seed)\n",
    "    nobs = X.shape[0]\n",
    "    nv = rng.uniform(0, 1, size=nobs)\n",
    "    c0 = rng.uniform(0.2, 0.3)\n",
    "    C = c0 * X[:, 1]\n",
    "    # Treatment compliance depends on mother education\n",
    "    treatment = C * w + X[:, 1] + nv\n",
    "    # Treatment effect depends on mother education and single-mom family at age 14\n",
    "    theta = 0.1 + 0.05 * X[:, 1] - 0.1 * X[:, 3]\n",
    "    # Additional effect on the outcome from mother education\n",
    "    f = 0.05 * X[:, 1]\n",
    "    # nv enters both the treatment and the outcome, so the two share a common shock\n",
    "    y = theta * (treatment + nv) + f + rng.normal(0, 0.1, size=nobs)\n",
    "\n",
    "    return y, treatment, theta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:02:58.593481Z",
     "start_time": "2020-06-22T19:02:58.548556Z"
    }
   },
   "outputs": [],
   "source": [
    "# Simulate on the imputed but unscaled covariates (Xdf); no seed is passed, so every run draws new data\n",
    "y_sim, treatment_sim, theta = semi_synth_nlsym(X=Xdf.values, w=w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:03:22.190936Z",
     "start_time": "2020-06-22T19:03:22.118980Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6089706667314586"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# True value\n",
    "theta.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:03:30.561613Z",
     "start_time": "2020-06-22T19:03:30.501630Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.6611532131769402, 0.013922622951893662)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2SLS estimate: w instruments the simulated treatment, X holds the scaled covariates\n",
    "iv_fit = IVRegressor()\n",
    "iv_fit.fit(X, treatment_sim, y_sim, w)\n",
    "# NOTE(review): predict() is unpacked as (estimate, standard error) -- confirm against IVRegressor\n",
    "ate, ate_sd = iv_fit.predict()\n",
    "ate, ate_sd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:07:04.924847Z",
     "start_time": "2020-06-22T19:07:04.875455Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.7501211540497275, 0.012800163754977008)"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# OLS estimate: regress the outcome on the treatment and covariates, without the instrument\n",
    "ols_design = sm.add_constant(np.column_stack((treatment_sim, X)), prepend=False)\n",
    "ols_fit = sm.OLS(y_sim, ols_design).fit()\n",
    "# prepend=False puts the constant last, so index 0 is the treatment column\n",
    "ols_fit.params[0], ols_fit.bse[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pure Synthetic Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The data generation mechanism is described in Hong and Nekipelov, \"*Semiparametric Efficiency in Nonlinear LATE Models*\" (2010)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:25:40.879421Z",
     "start_time": "2020-06-22T19:25:40.825352Z"
    }
   },
   "outputs": [],
   "source": [
    "def synthetic_data(n=10000, random_seed=None):\n",
    "    \"\"\"Generate a fully synthetic sample with a binary instrument and treatment.\n",
    "\n",
    "    Args:\n",
    "        n (int): number of observations to draw.\n",
    "        random_seed (int, optional): seed for the random draws. None gives a\n",
    "            different sample on every call.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (y, x, d, z, late) where y is the observed outcome, x the covariate,\n",
    "            d the treatment received, z the instrument, and late the difference in\n",
    "            mean potential outcomes (y1 - y0) among units with d1 > d0 (nan if the\n",
    "            sample contains no such unit).\n",
    "    \"\"\"\n",
    "    # A private generator produces the same draws as np.random.seed(random_seed)\n",
    "    # followed by np.random.* calls, without resetting the global NumPy RNG.\n",
    "    rng = np.random.RandomState(random_seed)\n",
    "\n",
    "    # Potential treatments with the instrument on (d1) and off (d0)\n",
    "    gamma0 = -0.5\n",
    "    gamma1 = 1.0\n",
    "    delta = 1.0\n",
    "    x = rng.uniform(size=n)\n",
    "    v = rng.normal(size=n)\n",
    "    d1 = (gamma0 + x * gamma1 + delta + v >= 0).astype(float)\n",
    "    d0 = (gamma0 + x * gamma1 + v >= 0).astype(float)\n",
    "    always_taker = (d1 == 1) & (d0 == 1)\n",
    "    never_taker = (d1 == 0) & (d0 == 0)\n",
    "    complier = d1 > d0\n",
    "\n",
    "    # Poisson components of the potential outcomes (draw order is kept fixed)\n",
    "    alpha = 1.0\n",
    "    beta = 0.5\n",
    "    lambda11 = 2.0\n",
    "    lambda00 = 1.0\n",
    "    xi1 = rng.poisson(np.exp(alpha + x * beta))\n",
    "    xi2 = rng.poisson(np.exp(x * beta))\n",
    "    xi3 = rng.poisson(np.exp(lambda11), size=n)\n",
    "    xi4 = rng.poisson(np.exp(lambda00), size=n)\n",
    "\n",
    "    # xi3 / xi4 shift both potential outcomes for always-takers / never-takers\n",
    "    y1 = xi1 + xi3 * always_taker + xi4 * never_taker\n",
    "    y0 = xi2 + xi3 * always_taker + xi4 * never_taker\n",
    "\n",
    "    # Instrument assignment probability depends on x; observed d and y follow from it\n",
    "    z = rng.binomial(1, stats.norm.cdf(x))\n",
    "    d = d1 * z + d0 * (1 - z)\n",
    "    y = y1 * d + y0 * (1 - d)\n",
    "\n",
    "    return y, x, d, z, y1[complier].mean() - y0[complier].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:25:45.345780Z",
     "start_time": "2020-06-22T19:25:45.287971Z"
    }
   },
   "outputs": [],
   "source": [
    "# Draw one synthetic sample with the default size; no seed is passed, so every run draws new data\n",
    "y, x, d, z, late = synthetic_data(random_seed=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:26:26.235823Z",
     "start_time": "2020-06-22T19:26:26.191466Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.1789099526066353"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# True value\n",
    "late"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:26:48.455795Z",
     "start_time": "2020-06-22T19:26:48.402675Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2.1900472390231775, 0.2623695460540134)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2SLS estimate: z instruments the received treatment d, x is the single covariate\n",
    "iv_fit = IVRegressor()\n",
    "iv_fit.fit(x, d, y, z)\n",
    "# NOTE(review): predict() is unpacked as (estimate, standard error) -- confirm against IVRegressor\n",
    "ate, ate_sd = iv_fit.predict()\n",
    "ate, ate_sd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-22T19:27:33.354806Z",
     "start_time": "2020-06-22T19:27:33.307105Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5.3482879532439975, 0.09201397327077365)"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# OLS estimate: regress the outcome on the treatment and covariate, without the instrument\n",
    "ols_design = sm.add_constant(np.column_stack((d, x)), prepend=False)\n",
    "ols_fit = sm.OLS(y, ols_design).fit()\n",
    "# prepend=False puts the constant last, so index 0 is the treatment column\n",
    "ols_fit.params[0], ols_fit.bse[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
